# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "@xla//third_party/py:py_import.bzl",
    "py_import",
)
load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load(
    "//jaxlib:jax.bzl",
    "jax_source_package",
    "jax_wheel",
    "py_deps",
    "pytype_test",
)

# Applies collect_data_files to //jax. The resulting :transitive_py_data target
# is one of the entries of WHEEL_SOURCE_FILES.
# NOTE(review): presumably this gathers the data files reachable from //jax;
# confirm against python_wheel.bzl.
collect_data_files(
    name = "transitive_py_data",
    deps = ["//jax"],
)

# Applies transitive_py_deps to the jax targets listed below. The resulting
# :transitive_py_deps target is one of the entries of WHEEL_SOURCE_FILES, which
# feeds the wheel and source-package targets in this file.
# NOTE(review): presumably a Python target has to be listed here (or be a
# transitive dependency of a listed target) for its sources to be packaged;
# confirm against python_wheel.bzl.
transitive_py_deps(
    name = "transitive_py_deps",
    deps = [
        "//jax",
        "//jax:compilation_cache",
        "//jax:experimental",
        "//jax:experimental_colocated_python",
        "//jax:experimental_sparse",
        "//jax:lax_reference",
        "//jax:pallas_experimental_gpu_ops",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_mosaic_gpu",
        "//jax:pallas_tpu_ops",
        "//jax:pallas_triton",
        "//jax:source_mapper",
        "//jax:sparse_test_util",
        "//jax:test_util",
        "//jax/_src/lib",
        "//jax/_src/pallas/fuser",
        "//jax/_src/pallas/mosaic_gpu",
        "//jax/experimental/array_serialization:serialization",
        "//jax/experimental/jax2tf",
        "//jax/extend",
        "//jax/extend:ifrt_programs",
        "//jax/extend/mlir",
        "//jax/extend/mlir/dialects",
        "//jax/tools:colab_tpu",
        "//jax/tools:jax_to_ir",
        "//jax/tools:pgo_nsys_converter",
    ],
)

# Python binary wrapping build_wheel.py. It is passed as `wheel_binary` to the
# jax_wheel targets and as `source_package_binary` to :jax_source_package.
py_binary(
    name = "build_wheel",
    srcs = ["build_wheel.py"],
    deps = [
        "//jaxlib/tools:build_utils",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)

# Inputs shared by :jax_wheel, :jax_wheel_editable and :jax_source_package.
# The two groups are concatenated in this order.
WHEEL_SOURCE_FILES = [
    # Labels with an explicit target part: two targets from this file and one
    # from //jax.
    ":transitive_py_data",
    ":transitive_py_deps",
    "//jax:py.typed",
] + [
    # Bare names, which Bazel resolves relative to this package.
    "AUTHORS",
    "LICENSE",
    "README.md",
    "pyproject.toml",
    "setup.py",
]

# The `jax` wheel: declared platform-independent, built by :build_wheel from
# WHEEL_SOURCE_FILES. Consumed in this file by :jax_py_import and
# :jax_wheel_size_test.
jax_wheel(
    name = "jax_wheel",
    platform_independent = True,
    source_files = WHEEL_SOURCE_FILES,
    wheel_binary = ":build_wheel",
    wheel_name = "jax",
)

# Editable variant of the `jax` wheel: same sources, binary and wheel name as
# :jax_wheel, with `editable = True` as the only difference.
jax_wheel(
    name = "jax_wheel_editable",
    editable = True,
    platform_independent = True,
    source_files = WHEEL_SOURCE_FILES,
    wheel_binary = ":build_wheel",
    wheel_name = "jax",
)

# Source package for `jax`. Uses the same file list (WHEEL_SOURCE_FILES) and
# the same :build_wheel binary as the wheel targets.
jax_source_package(
    name = "jax_source_package",
    source_files = WHEEL_SOURCE_FILES,
    source_package_binary = ":build_wheel",
    source_package_name = "jax",
)

# Packs the sources listed below into wheel_additives.zip using Bazel's zipper
# tool (`c` is zipper's "create archive" mode; `$@` is the single output and
# `$(SRCS)` expands to every file in `srcs`). The archive is passed as
# `wheel_deps` to the two py_import targets in this file.
genrule(
    name = "wheel_additives",
    srcs = [
        "//jax:internal_export_back_compat_test_util",
        "//jax:internal_test_harnesses",
        "//jax:internal_test_util",
        "//jax:internal_export_back_compat_test_data",
        "//jax:experimental/pallas/ops/tpu/random/philox.py",
        "//jax:experimental/pallas/ops/tpu/random/prng_utils.py",
        "//jax:experimental/pallas/ops/tpu/random/threefry.py",
        "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
        "//jax/experimental/mosaic/gpu/examples:matmul.py",
    ],
    outs = ["wheel_additives.zip"],
    cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)",
    tools = ["@bazel_tools//tools/zip:zipper"],
)

# Python package dependencies shared by :jax_py_import and
# :jax_wheel_with_internal_test_util. The names are passed through py_deps
# (loaded from //jaxlib:jax.bzl).
# NOTE(review): presumably py_deps maps each name to the matching pip
# dependency label; confirm in jax.bzl.
COMMON_DEPS = py_deps([
    "absl/testing",
    "numpy",
    "ml_dtypes",
    "scipy",
    "opt_einsum",
    "hypothesis",
    "cloudpickle",
    "flatbuffers",
])

# py_import of the wheel built in this package (:jax_wheel), with
# :wheel_additives as `wheel_deps` and COMMON_DEPS as `deps`.
py_import(
    name = "jax_py_import",
    wheel = ":jax_wheel",
    wheel_deps = [":wheel_additives"],
    deps = COMMON_DEPS,
)

# Adds extra sources (:wheel_additives) on top of the jax wheel taken from
# @pypi_jax//:whl.
# Tests that depend on jax need this when they use modules that are not part of
# the jax wheel but share package paths with modules that are in the wheel.
py_import(
    name = "jax_wheel_with_internal_test_util",
    wheel = "@pypi_jax//:whl",
    wheel_deps = [":wheel_additives"],
    deps = COMMON_DEPS,
)

# Runs //jaxlib/tools:wheel_size_test.py against :jax_wheel, passing the wheel
# location and a size limit via --max-size-mib=5. The wheel is listed in `data`
# so that $(location :jax_wheel) can be expanded in `args`.
# The "manual" tag keeps the test out of wildcard target patterns such as
# `//...`, so it only runs when requested explicitly.
pytype_test(
    name = "jax_wheel_size_test",
    srcs = ["//jaxlib/tools:wheel_size_test.py"],
    args = [
        "--wheel-path=$(location :jax_wheel)",
        "--max-size-mib=5",
    ],
    data = [":jax_wheel"],
    main = "wheel_size_test.py",
    tags = [
        "manual",
        "notap",
    ],
)
